# -*- coding: utf-8 -*-
"""
Created on Fri Jan 17 14:21:13 2020

@author: gxjco
"""
import numpy as np
import sklearn.metrics
from sklearn.preprocessing import KBinsDiscretizer

def discrete_mutual_info(mus, ys, num_bin):
    """Compute discrete mutual information between codes and factors.

    Each row of ``mus`` is discretized into ``num_bin`` bins with
    ``make_discretizer`` and compared against every row of ``ys`` with
    ``sklearn.metrics.mutual_info_score``. Rows of ``ys`` are passed through
    unchanged, so they are presumably already discrete labels (TODO confirm
    with callers).

    Args:
        mus: 2-D array indexed as ``mus[code, sample]``.
        ys: 2-D array indexed as ``ys[factor, sample]``; must have the same
            number of samples (columns) as ``mus``.
        num_bin: number of bins used to discretize each row of ``mus``.

    Returns:
        Array of shape ``(mus.shape[0], ys.shape[0])`` whose entry ``[i, j]``
        is the mutual information between discretized ``mus[i]`` and ``ys[j]``.
    """
    num_codes = mus.shape[0]
    num_factors = ys.shape[0]
    m = np.zeros([num_codes, num_factors])
    for i in range(num_codes):
        # The discretized code depends only on i, so compute it once per row
        # rather than once per (i, j) pair.
        code_binned = make_discretizer(mus[i, :], num_bin)
        for j in range(num_factors):
            m[i, j] = sklearn.metrics.mutual_info_score(ys[j, :], code_binned)
    return m

def discrete_entropy(ys, num_bin):
    """Compute the discrete entropy of each factor.

    Each row of ``ys`` is discretized into ``num_bin`` bins with
    ``make_discretizer``. Its entropy is then obtained as the mutual
    information of the discretized row with itself, since I(X; X) = H(X).
    The units are those of ``sklearn.metrics.mutual_info_score`` (nats).

    Args:
        ys: 2-D array indexed as ``ys[factor, sample]``.
        num_bin: number of bins used to discretize each row.

    Returns:
        1-D array of length ``ys.shape[0]`` with one entropy per row.
    """
    num_factors = ys.shape[0]
    h = np.zeros(num_factors)
    for j in range(num_factors):
        # Discretize the row once and use it for both arguments.
        binned = make_discretizer(ys[j, :], num_bin)
        h[j] = sklearn.metrics.mutual_info_score(binned, binned)
    return h

def make_discretizer(target, num_bins):
    """Bin the values of ``target`` into ``num_bins`` ordinal bins.

    ``target`` is flattened into a single feature column and passed to a
    freshly fitted ``KBinsDiscretizer`` (ordinal encoding, library-default
    binning strategy). The bin index of every value is returned as a flat
    1-D float array with one entry per element of ``target``.
    """
    as_column = target.reshape(-1, 1)
    binner = KBinsDiscretizer(num_bins, encode='ordinal')
    bin_ids = binner.fit_transform(as_column)
    return bin_ids.reshape(-1)

def MIG_compute(file, prop, num_bin=20, num_latent=100, num_props=3):
    """Print normalized mutual information between latent and property columns.

    Two arrays are loaded from the working directory:

    * ``prop + '_z_' + file + '.npy'``: its first ``num_latent`` columns are
      scored (presumably one column per latent dimension; confirm against
      the code that writes this file).
    * ``prop + '_' + file + '.npy'``: its first ``num_props`` columns are
      scored (presumably one column per property).

    Every column is discretized into ``num_bin`` bins with
    ``make_discretizer``. For each property column j and latent column i,
    ``sklearn.metrics.normalized_mutual_info_score`` is computed. For each
    property, four values are printed:

    * ``z_avg``: the mean score over latent columns 3 onward.
    * ``z_weight``, ``z_logp``, ``z_logs``: the scores of latent columns
      0, 1 and 2.

    Args:
        file: suffix inserted into both file names.
        prop: prefix of both file names.
        num_bin: bins per discretized column (default 20).
        num_latent: number of latent columns to score (default 100).
        num_props: number of property columns to score (default 3).

    Returns:
        ``np.ndarray`` of shape ``(num_props, num_latent)``. Entry ``[j, i]``
        is the score between latent column i and property column j.
    """
    # Load each file once rather than once per (property, latent) pair.
    z = np.load(prop + '_z_' + file + '.npy')
    props = np.load(prop + '_' + file + '.npy')

    # The discretized latent columns do not depend on the property, so
    # compute them once and reuse them for every property.
    z_binned = [make_discretizer(z[:, i], num_bin) for i in range(num_latent)]

    scores = np.zeros((num_props, num_latent))
    for j in range(num_props):
        prop_binned = make_discretizer(props[:, j], num_bin)
        for i, factor_binned in enumerate(z_binned):
            scores[j, i] = sklearn.metrics.normalized_mutual_info_score(
                factor_binned, prop_binned)
        m_s = scores[j]
        print('z_avg', np.mean(m_s[3:]))
        print('z_weight', m_s[0])
        print('z_logp', m_s[1])
        print('z_logs', m_s[2])
    return scores



if __name__ == '__main__':
    # MIG_compute requires both a file suffix and a property prefix, so read
    # them from the command line. The file suffix defaults to ''.
    import argparse

    parser = argparse.ArgumentParser(
        description='Print normalized mutual information between latent '
                    'columns and property columns loaded from .npy files.')
    parser.add_argument(
        'prop',
        help="prefix of the input files ('<prop>_z_<file>.npy' and "
             "'<prop>_<file>.npy')")
    parser.add_argument(
        '--file', default='',
        help="suffix of the input files (default: '')")
    args = parser.parse_args()
    MIG_compute(args.file, args.prop)
